{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据输入（二） image 数据\n",
    "\n",
    "上个例子中学习了 numpy 数据的读取方式，已经可以解决不少问题了。现在还有两个问题可能需要用到不同的数据打包方式，一个是图片数据的读取；另一个就是变长序列的读取。图片的例子不少，但是关于变长序列我还没找到很好的解决方式。\n",
    "\n",
    "本文主要参考：\n",
    "- [TensorFlow全新的数据读取方式：Dataset API入门教程](https://zhuanlan.zhihu.com/p/30751039)\n",
    "- https://github.com/yongyehuang/TensorFlow-Examples/blob/master/examples/5_DataManagement/build_an_image_dataset.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "在 data/sketchy_000000000000/ 目录下一共有 125 个目录，每个目录表示一个类别，每个目录下有多张图片。\n",
    "airplane/\n",
    "    ***1.png\n",
    "    ***2.png\n",
    "    ...\n",
    "ant/ \n",
    "    ***1.png\n",
    "    ***2.png\n",
    "    ...\n",
    "...\n",
    "   \n",
    "```\n",
    "\n",
    "- step1: 获取每张图片的文件名和对应的标签。\n",
    "- step2: 使用 dataset.map 函数解析图片。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75481\n",
      "75481\n",
      "../data/sketchy_000000000000/airplane/n02691156_10151-1.png\n"
     ]
    }
   ],
   "source": [
    "def get_file_path(data_path = '../data/sketchy_000000000000/'):\n",
    "    \"\"\"解析文件夹，获取每个文件的路径和标签。\"\"\"\n",
    "    img_paths = list()\n",
    "    labels = list()\n",
    "    class_dirs = sorted(os.listdir(data_path))\n",
    "    dict_class2id = dict()\n",
    "    for i in range(len(class_dirs)):\n",
    "        label = i\n",
    "        class_dir = class_dirs[i]\n",
    "        dict_class2id[class_dir] = label\n",
    "        class_path = os.path.join(data_path, class_dir)  # 每类的路径\n",
    "        file_names = sorted(os.listdir(class_path))\n",
    "        for file_name in file_names:\n",
    "            file_path = os.path.join(class_path, file_name)\n",
    "            img_paths.append(file_path)\n",
    "            labels.append(label)\n",
    "    return img_paths, labels\n",
    "\n",
    "img_paths, labels = get_file_path()\n",
    "print(len(img_paths))\n",
    "print(len(labels))\n",
    "img0 = img_paths[0]\n",
    "print(img0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造 dataset 并读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_png(img_path, label, height=256, width=256, channel=3):\n",
    "    \"\"\"根据 img_path 读入图片并做相应处理\"\"\"\n",
    "    # 从硬盘上读取图片\n",
    "    img = tf.read_file(img_path)\n",
    "    img_decoded = tf.image.decode_png(img, channels=channel)\n",
    "    # resize\n",
    "    img_resized = tf.image.resize_images(img_decoded, [height, width])\n",
    "    # normalize \n",
    "    img_norm = img_resized * 1.0 / 127.5 - 1.0\n",
    "    return img_norm, label\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))\n",
    "dataset = dataset.map(parse_png)\n",
    "print('parsing image', dataset)\n",
    "dataset = dataset.shuffle(buffer_size=5000).repeat().batch(256)\n",
    "print('batch', dataset)\n",
    "\n",
    "# 生成迭代器\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "print(iterator)\n",
    "\n",
    "time0 = time.time()\n",
    "for count in range(1000):\n",
    "    X_batch, y_batch = sess.run(iterator.get_next())\n",
    "#     print('count = {} : X.shape = {}, y[:10] = {}, pass {}s'.format(count, X_batch.shape, y_batch[:10], time.time() - time0))\n",
    "#     time0 = time.time()\n",
    "print(time.time() - time0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  batch_size=4 ，no shuffle\n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f0bd00e05c0>\n",
    "count = 0 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.4704289436340332s\n",
    "count = 1 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.38211870193481445s\n",
    "count = 2 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.4049665927886963s\n",
    "count = 3 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.38713955879211426s\n",
    "count = 4 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.3735225200653076s\n",
    "count = 5 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.37211036682128906s\n",
    "count = 6 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.36931586265563965s\n",
    "count = 7 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.36876487731933594s\n",
    "count = 8 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.3712954521179199s\n",
    "count = 9 : X.shape = (4, 256, 256, 3), y[:10] = [0 0 0 0], pass 0.38353633880615234s\n",
    "```\n",
    "\n",
    "\n",
    "###  batch_size=4 ，buffer_size=1000\n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f81c9faba90>\n",
    "count = 0 : X.shape = (4, 256, 256, 3), y = [0 0 0 0], pass 2.277726411819458s\n",
    "count = 1 : X.shape = (4, 256, 256, 3), y = [0 0 1 1], pass 0.3655529022216797s\n",
    "count = 2 : X.shape = (4, 256, 256, 3), y = [0 1 0 0], pass 0.3675987720489502s\n",
    "count = 3 : X.shape = (4, 256, 256, 3), y = [1 0 0 0], pass 0.3594975471496582s\n",
    "count = 4 : X.shape = (4, 256, 256, 3), y = [0 0 0 0], pass 0.35462403297424316s\n",
    "count = 5 : X.shape = (4, 256, 256, 3), y = [0 0 0 1], pass 0.3697807788848877s\n",
    "count = 6 : X.shape = (4, 256, 256, 3), y = [0 0 0 0], pass 0.37018513679504395s\n",
    "count = 7 : X.shape = (4, 256, 256, 3), y = [1 0 0 0], pass 0.3580958843231201s\n",
    "count = 8 : X.shape = (4, 256, 256, 3), y = [0 0 1 1], pass 0.35622262954711914s\n",
    "count = 9 : X.shape = (4, 256, 256, 3), y = [0 1 0 0], pass 0.35665392875671387s\n",
    "```\n",
    "###  batch_size=4 ，buffer_size=5000\n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f8de8a718d0>\n",
    "count = 0 : X.shape = (4, 256, 256, 3), y = [3 6 4 4], pass 10.296564102172852s\n",
    "count = 1 : X.shape = (4, 256, 256, 3), y = [5 0 4 3], pass 0.43836045265197754s\n",
    "count = 2 : X.shape = (4, 256, 256, 3), y = [4 7 4 2], pass 0.4133310317993164s\n",
    "count = 3 : X.shape = (4, 256, 256, 3), y = [6 8 0 4], pass 0.37926673889160156s\n",
    "count = 4 : X.shape = (4, 256, 256, 3), y = [6 7 0 1], pass 0.41953492164611816s\n",
    "count = 5 : X.shape = (4, 256, 256, 3), y = [3 4 6 6], pass 0.39876246452331543s\n",
    "count = 6 : X.shape = (4, 256, 256, 3), y = [1 5 2 7], pass 0.39066076278686523s\n",
    "count = 7 : X.shape = (4, 256, 256, 3), y = [3 8 7 0], pass 0.4210829734802246s\n",
    "count = 8 : X.shape = (4, 256, 256, 3), y = [0 2 0 2], pass 0.39181971549987793s\n",
    "count = 9 : X.shape = (4, 256, 256, 3), y = [6 3 5 2], pass 0.3952939510345459s\n",
    "```\n",
    "\n",
    "###  batch_size=4 ，buffer_size=10000\n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fe7971adb70>\n",
    "count = 0 : X.shape = (4, 256, 256, 3), y[:10] = [7 6 7 5], pass 20.642507076263428s\n",
    "count = 1 : X.shape = (4, 256, 256, 3), y[:10] = [ 5  1 13  5], pass 0.42054080963134766s\n",
    "count = 2 : X.shape = (4, 256, 256, 3), y[:10] = [0 8 3 9], pass 0.41022515296936035s\n",
    "count = 3 : X.shape = (4, 256, 256, 3), y[:10] = [5 3 0 4], pass 0.38398194313049316s\n",
    "count = 4 : X.shape = (4, 256, 256, 3), y[:10] = [ 3 13 10  2], pass 0.39125919342041016s\n",
    "count = 5 : X.shape = (4, 256, 256, 3), y[:10] = [ 9 14 11  3], pass 0.3871769905090332s\n",
    "count = 6 : X.shape = (4, 256, 256, 3), y[:10] = [ 0  3 12  1], pass 0.39249467849731445s\n",
    "count = 7 : X.shape = (4, 256, 256, 3), y[:10] = [10 11  4  7], pass 0.39815545082092285s\n",
    "count = 8 : X.shape = (4, 256, 256, 3), y[:10] = [6 3 3 4], pass 0.3916475772857666s\n",
    "count = 9 : X.shape = (4, 256, 256, 3), y[:10] = [ 7 13 13  5], pass 0.3907802104949951s\n",
    "```\n",
    "\n",
    "\n",
    "-----\n",
    "### batch_size=256， no shuffle \n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fa731d6da58>\n",
    "count = 0 : X.shape = (256, 256, 256, 3), y[:10] = [0 1 0 0 0 0 0 1 0 0], pass 2.809150218963623s\n",
    "count = 1 : X.shape = (256, 256, 256, 3), y[:10] = [0 0 0 1 0 1 0 1 1 0], pass 0.9180774688720703s\n",
    "count = 2 : X.shape = (256, 256, 256, 3), y[:10] = [1 1 2 1 1 1 0 0 0 2], pass 0.9346837997436523s\n",
    "count = 3 : X.shape = (256, 256, 256, 3), y[:10] = [0 1 0 1 0 2 1 1 2 1], pass 0.9159815311431885s\n",
    "count = 4 : X.shape = (256, 256, 256, 3), y[:10] = [1 1 0 0 3 0 3 0 3 0], pass 0.9176328182220459s\n",
    "count = 5 : X.shape = (256, 256, 256, 3), y[:10] = [1 2 1 3 2 1 3 2 3 3], pass 0.9095683097839355s\n",
    "count = 6 : X.shape = (256, 256, 256, 3), y[:10] = [3 3 1 1 3 1 1 3 1 3], pass 0.9027786254882812s\n",
    "count = 7 : X.shape = (256, 256, 256, 3), y[:10] = [4 4 2 2 2 3 3 3 0 0], pass 0.8944077491760254s\n",
    "count = 8 : X.shape = (256, 256, 256, 3), y[:10] = [3 4 3 4 2 4 3 4 4 3], pass 0.9207658767700195s\n",
    "count = 9 : X.shape = (256, 256, 256, 3), y[:10] = [0 2 5 1 4 5 4 0 5 5], pass 0.9587688446044922s\n",
    "```\n",
    "\n",
    "### batch_size=256， buffer_size=1000 \n",
    "\n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f6461269b38>\n",
    "count = 0 : X.shape = (256, 256, 256, 3), y[:10] = [0 1 1 0 0 0 1 1 1 0], pass 2.8092191219329834s\n",
    "count = 1 : X.shape = (256, 256, 256, 3), y[:10] = [0 0 1 0 1 0 1 1 0 0], pass 0.9225723743438721s\n",
    "count = 2 : X.shape = (256, 256, 256, 3), y[:10] = [0 0 0 0 0 0 1 1 2 2], pass 0.929398775100708s\n",
    "count = 3 : X.shape = (256, 256, 256, 3), y[:10] = [2 2 0 1 0 2 2 2 1 2], pass 0.9135322570800781s\n",
    "count = 4 : X.shape = (256, 256, 256, 3), y[:10] = [1 2 3 1 2 2 0 1 1 2], pass 0.9033732414245605s\n",
    "count = 5 : X.shape = (256, 256, 256, 3), y[:10] = [0 3 0 3 3 3 3 3 2 2], pass 0.8910338878631592s\n",
    "count = 6 : X.shape = (256, 256, 256, 3), y[:10] = [1 0 3 3 3 3 3 2 0 1], pass 0.9181504249572754s\n",
    "count = 7 : X.shape = (256, 256, 256, 3), y[:10] = [3 4 3 3 3 4 2 1 1 3], pass 0.8899593353271484s\n",
    "count = 8 : X.shape = (256, 256, 256, 3), y[:10] = [2 4 4 4 3 1 3 4 4 4], pass 0.8846793174743652s\n",
    "count = 9 : X.shape = (256, 256, 256, 3), y[:10] = [4 3 0 4 4 4 4 2 5 4], pass 0.8913383483886719s\n",
    "```\n",
    "\n",
    "### batch_size=256， buffer_size=5000 \n",
    "```\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f5d6ba2cb38>\n",
    "count = 0 : X.shape = (256, 256, 256, 3), y[:10] = [0 7 4 0 7 6 1 8 5 3], pass 11.391936540603638s\n",
    "count = 1 : X.shape = (256, 256, 256, 3), y[:10] = [6 7 5 1 6 2 7 4 1 6], pass 1.4594790935516357s\n",
    "count = 2 : X.shape = (256, 256, 256, 3), y[:10] = [3 0 2 4 6 6 7 2 5 4], pass 1.4608898162841797s\n",
    "count = 3 : X.shape = (256, 256, 256, 3), y[:10] = [7 9 1 1 5 6 8 3 5 1], pass 1.2903695106506348s\n",
    "count = 4 : X.shape = (256, 256, 256, 3), y[:10] = [5 4 0 6 8 9 8 3 3 8], pass 1.3749797344207764s\n",
    "count = 5 : X.shape = (256, 256, 256, 3), y[:10] = [2 2 3 4 8 6 1 8 4 7], pass 1.3508849143981934s\n",
    "count = 6 : X.shape = (256, 256, 256, 3), y[:10] = [8 7 7 9 0 1 0 6 2 6], pass 1.3386600017547607s\n",
    "count = 7 : X.shape = (256, 256, 256, 3), y[:10] = [10  7  0  4  8  6  8  2  1  1], pass 1.38677978515625s\n",
    "count = 8 : X.shape = (256, 256, 256, 3), y[:10] = [ 9  0  9  2  6  1  7  7  3 10], pass 1.3200161457061768s\n",
    "count = 9 : X.shape = (256, 256, 256, 3), y[:10] = [ 4  7  3  8  8  0 10  1  9  3], pass 1.3426642417907715s\n",
    "```\n",
    "\n",
    "### batch_size=256， buffer_size=10000 \n",
    "```\n",
    "parsing image <MapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>\n",
    "batch <BatchDataset shapes: ((?, 256, 256, 3), (?,)), types: (tf.float32, tf.int32)>\n",
    "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fecb8eafb38>\n",
    "count = 0 : X.shape = (256, 256, 256, 3), y[:10] = [ 1 11 12  2  3  3  2 10 13  8], pass 26.30490732192993s\n",
    "count = 1 : X.shape = (256, 256, 256, 3), y[:10] = [ 0  1 13  9 16  6 16 14  0 12], pass 1.4967591762542725s\n",
    "count = 2 : X.shape = (256, 256, 256, 3), y[:10] = [12  8 10  9 14  3  3  4 15  4], pass 1.6755847930908203s\n",
    "count = 3 : X.shape = (256, 256, 256, 3), y[:10] = [ 9 16  1 14 17  0 11 15  5  0], pass 1.2678115367889404s\n",
    "count = 4 : X.shape = (256, 256, 256, 3), y[:10] = [10 15 16  1  0  6 11 10 16  4], pass 1.2361955642700195s\n",
    "count = 5 : X.shape = (256, 256, 256, 3), y[:10] = [10 13 10 15 12 13  7  0  1 14], pass 1.4305830001831055s\n",
    "count = 6 : X.shape = (256, 256, 256, 3), y[:10] = [10  7  1 14  3 13 14 12  4  0], pass 1.262021541595459s\n",
    "count = 7 : X.shape = (256, 256, 256, 3), y[:10] = [16 15 17 19  5 14  8  6  1 16], pass 1.4019050598144531s\n",
    "count = 8 : X.shape = (256, 256, 256, 3), y[:10] = [ 5  9  9 18 12 14 14  0  3 17], pass 1.3533422946929932s\n",
    "count = 9 : X.shape = (256, 256, 256, 3), y[:10] = [ 0  8 19 14  2 10 17  9  7  2], pass 1.429227352142334s\n",
    "```"
   ]
  },
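  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logs above show two patterns: the first batch takes longer as buffer_size grows, and with a small buffer only the first few class ids appear. Both are consistent with how a shuffle buffer behaves: it has to fill up before it yields anything, and because the file list is sorted by class, a buffer of 1000 only spans the first couple of classes. The pure-Python sketch below imitates that behaviour. It is an illustration, not TensorFlow's implementation; the helper `shuffle_buffer` and the figure of 600 images per class (roughly 75481 / 125) are assumptions made for the demo.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import random\n",
    "\n",
    "\n",
    "def shuffle_buffer(items, buffer_size, seed=0):\n",
    "    '''Yield items in the order a fixed-size shuffle buffer would emit them (buffer_size >= 1).'''\n",
    "    rng = random.Random(seed)\n",
    "    it = iter(items)\n",
    "    # Fill the buffer first; nothing is emitted until this is done.\n",
    "    buf = list(itertools.islice(it, buffer_size))\n",
    "    for item in it:\n",
    "        # Emit a random buffered element and put the new item in its slot.\n",
    "        idx = rng.randrange(len(buf))\n",
    "        yield buf[idx]\n",
    "        buf[idx] = item\n",
    "    # Input exhausted: drain what is left in random order.\n",
    "    rng.shuffle(buf)\n",
    "    for item in buf:\n",
    "        yield item\n",
    "\n",
    "\n",
    "# Labels sorted by class, like the list returned by get_file_path().\n",
    "# 600 images per class is an assumed round number for this demo.\n",
    "labels_sorted = [c for c in range(125) for _ in range(600)]\n",
    "first = list(itertools.islice(shuffle_buffer(labels_sorted, 1000), 40))\n",
    "print(sorted(set(first)))  # only classes 0 and 1 can show up this early\n",
    "```\n",
    "\n",
    "In the notebook the shuffle comes after `map`, so filling the buffer means decoding buffer_size images before the first batch is returned, which would explain why the startup time grows with buffer_size."
   ]
  },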
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "大概整理了一下不同 batch_size 和 不同 buffle size 时候取 batch 的速度，下面是在机械硬盘上的实验结果，因为这台机器上没有 SSD，所以还没有试 SSD 上的速度，应该能快好多倍。\n",
    "\n",
    "|batch_size|buffer_size|启动时间(s)|每个batch(s)|\n",
    "|:----:|:---:|:---:|:---:|\n",
    "|4|0|0|0.37|\n",
    "|4|1000|2|0.37|\n",
    "|4|5000|10|0.39|\n",
    "|4|10000|20|0.39|\n",
    "|256|0|2|0.9|\n",
    "|256|1000|2|0.9|\n",
    "|256|5000|10|1.35|\n",
    "|256|10000|20|1.35(方差大)|\n",
    "\n",
    "从上面来看，第一感觉就是：**妈呀，咋这么慢呀！这个 IO 估计都比网络训练的时间还多了。**\n",
    "\n",
    "下面是采用 队列 的方式来进行数据读取，速度要快很多很多。"
   ]
  },
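  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before switching to queues, a hedged note on the Dataset pipeline: it shuffles already-decoded images and decodes them one at a time. A variant that might reduce both the startup delay and the time per batch is to shuffle the cheap (path, label) pairs before `map`, decode in parallel with `num_parallel_calls`, and `prefetch` one batch. The sketch below assumes TF 1.4 or later and reuses `img_paths`, `labels` and `sess` from the cells above; `build_dataset` and `num_threads=4` are illustrative choices, and this variant has not been timed on this dataset.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def parse_png(img_path, label, height=256, width=256, channel=3):\n",
    "    '''Read one png file, resize it and scale the pixels to [-1, 1].'''\n",
    "    img = tf.read_file(img_path)\n",
    "    img = tf.image.decode_png(img, channels=channel)\n",
    "    img = tf.image.resize_images(img, [height, width])\n",
    "    return img / 127.5 - 1.0, label\n",
    "\n",
    "\n",
    "def build_dataset(img_paths, labels, batch_size=256, num_threads=4):\n",
    "    '''Hypothetical helper: shuffle paths first, then decode in parallel.'''\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))\n",
    "    # Shuffling strings and ints is cheap, so the buffer can cover the whole file list.\n",
    "    dataset = dataset.shuffle(buffer_size=len(img_paths)).repeat()\n",
    "    # Decode several images at once instead of one by one.\n",
    "    dataset = dataset.map(parse_png, num_parallel_calls=num_threads)\n",
    "    # Prepare the next batch while the current one is being consumed.\n",
    "    return dataset.batch(batch_size).prefetch(1)\n",
    "\n",
    "\n",
    "# Usage inside this notebook (img_paths, labels and sess come from earlier cells):\n",
    "# dataset = build_dataset(img_paths, labels)\n",
    "# next_batch = dataset.make_one_shot_iterator().get_next()  # build the op once\n",
    "# X_batch, y_batch = sess.run(next_batch)\n",
    "```\n",
    "\n",
    "Shuffling before `map` also gives a full shuffle over all classes, which the buffer sizes tried above do not."
   ]
  },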
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loop 99, pass 11.43s"
     ]
    }
   ],
   "source": [
    "\"\"\"Use tf.data.Dataset to create dataset for image(png) data.\n",
    "With TF Queue, shuffle data\n",
    "\n",
    "refer: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/build_an_image_dataset.py\n",
    "\"\"\"\n",
    "\n",
    "from __future__ import print_function\n",
    "from __future__ import division\n",
    "from __future__ import absolute_import\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "import time\n",
    "\n",
    "\n",
    "def get_file_path(data_path='../data/sketchy_000000000000/'):\n",
    "    \"\"\"解析文件夹，获取每个文件的路径和标签。\"\"\"\n",
    "    img_paths = list()\n",
    "    labels = list()\n",
    "    class_dirs = sorted(os.listdir(data_path))\n",
    "    dict_class2id = dict()\n",
    "    for i in range(len(class_dirs)):\n",
    "        label = i\n",
    "        class_dir = class_dirs[i]\n",
    "        dict_class2id[class_dir] = label\n",
    "        class_path = os.path.join(data_path, class_dir)  # 每类的路径\n",
    "        file_names = sorted(os.listdir(class_path))\n",
    "        for file_name in file_names:\n",
    "            file_path = os.path.join(class_path, file_name)\n",
    "            img_paths.append(file_path)\n",
    "            labels.append(label)\n",
    "    return img_paths, labels\n",
    "\n",
    "\n",
    "def get_batch(img_paths, labels, batch_size=128, height=256, width=256, channel=3):\n",
    "    \"\"\"根据 img_path 读入图片并做相应处理\"\"\"\n",
    "    # 从硬盘上读取图片\n",
    "    img_paths = np.asarray(img_paths)\n",
    "    labels = np.asarray(labels)\n",
    "\n",
    "    img_paths = tf.convert_to_tensor(img_paths, dtype=tf.string)\n",
    "    labels = tf.convert_to_tensor(labels, dtype=tf.int32)\n",
    "    # Build a TF Queue, shuffle data\n",
    "    image, label = tf.train.slice_input_producer([img_paths, labels], shuffle=True)\n",
    "    # Read images from disk\n",
    "    image = tf.read_file(image)\n",
    "    image = tf.image.decode_jpeg(image, channels=channel)\n",
    "    # Resize images to a common size\n",
    "    image = tf.image.resize_images(image, [height, width])\n",
    "    # Normalize\n",
    "    image = image * 1.0 / 127.5 - 1.0\n",
    "    # Create batches\n",
    "    X_batch, y_batch = tf.train.batch([image, label], batch_size=batch_size,\n",
    "                                      capacity=batch_size * 8,\n",
    "                                      num_threads=4)\n",
    "    return X_batch, y_batch\n",
    "\n",
    "\n",
    "img_paths, labels = get_file_path()\n",
    "X_batch, y_batch = get_batch(img_paths, labels)\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "tf.train.start_queue_runners(sess=sess)\n",
    "\n",
    "time0 = time.time()\n",
    "for count in range(100):   # 11s for 100batch\n",
    "    _X_batch, _y_batch = sess.run([X_batch, y_batch])\n",
    "    sys.stdout.write(\"\\rloop {}, pass {:.2f}s\".format(count, time.time() - time0))\n",
    "    sys.stdout.flush()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于 png 数据的读取,我尝试了 3 组不同的方式: one-shot 方式, tf 的队列方式(queue), tfrecord 方式. 同样是在机械硬盘上操作, 结果是 tfrecord 方式明显要快一些.\n",
    "\n",
    "|iter_mode|buffer_size|100 batch(s)|\n",
    "|:----:|:---:|:---:|\n",
    "|one-shot|2000|75|\n",
    "|one-shot|5000|86|\n",
    "|tf.queue|2000|11|\n",
    "|tf.queue|5000|11|\n",
    "|tfrecord（3线程）|2000|5.3|\n",
    "|tfrecord（3线程）|5000|5.3|\n",
    "\n",
    "如果是在 SSD 上面的话,tf 的队列方式应该也是比较快的.打包成 tfrecord 格式只是减少了小文件的读取，其实现也是使用队列的。\n",
    "\n",
    "下面在 tfrecord 方式下，比较对 image 进行 resize (256 -> 224)操作，batch_size=128：\n",
    "\n",
    "|线程数|是否resize|500 batch(s)|\n",
    "|:----:|:---:|:---:|\n",
    "|2|否|35.58|\n",
    "|4|否|21.43|\n",
    "|6|否|16.35|\n",
    "|2|是|80|\n",
    "|4|是|44|\n",
    "|6|是|34|"
   ]
  },
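  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tfrecord code is not shown in this notebook. As a rough sketch of what packing the png files could look like, the snippet below stores the encoded png bytes and the label of each image in a single TFRecord file and parses one record back. It uses the TF 1.x API; the function names, the feature keys `image` and `label`, and the output file name are assumptions for illustration, not the setup used for the timings above.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def write_tfrecord(img_paths, labels, out_path='sketchy_train.tfrecord'):\n",
    "    '''Hypothetical helper: pack (png bytes, label) pairs into one TFRecord file.'''\n",
    "    with tf.python_io.TFRecordWriter(out_path) as writer:\n",
    "        for img_path, label in zip(img_paths, labels):\n",
    "            # Store the encoded png as-is; decoding happens at read time.\n",
    "            with open(img_path, 'rb') as f:\n",
    "                png_bytes = f.read()\n",
    "            feature = {\n",
    "                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes])),\n",
    "                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),\n",
    "            }\n",
    "            example = tf.train.Example(features=tf.train.Features(feature=feature))\n",
    "            writer.write(example.SerializeToString())\n",
    "\n",
    "\n",
    "def parse_example(serialized, channel=3):\n",
    "    '''Turn one serialized record back into an image tensor and a label.'''\n",
    "    features = tf.parse_single_example(serialized, features={\n",
    "        'image': tf.FixedLenFeature([], tf.string),\n",
    "        'label': tf.FixedLenFeature([], tf.int64),\n",
    "    })\n",
    "    image = tf.image.decode_png(features['image'], channels=channel)\n",
    "    return image, features['label']\n",
    "```\n",
    "\n",
    "One large sequential file replaces tens of thousands of small reads, which is the effect described above for tfrecord on a mechanical disk."
   ]
  },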
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
